{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 5 out of 7 words in the pre-training embedding.\n",
      "torch.Size([1, 5, 50])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import StaticEmbedding\n",
    "\n",
    "# Build the vocabulary from the demo sentence.\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# Pre-trained 50-dimensional GloVe vectors, selected by name.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
    "\n",
    "# Convert the text to vocabulary indices; the outer list is the batch dimension.\n",
    "token_ids = [vocab.to_index(tok) for tok in tokens]\n",
    "words = torch.LongTensor([token_ids])\n",
    "\n",
    "# StaticEmbedding is used in much the same way as pytorch's nn.Embedding.\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 5, 30])\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import StaticEmbedding\n",
    "\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# No pre-trained model is named here; only the embedding dimension (30) is given.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=30)\n",
    "\n",
    "# `torch` is not imported in this cell; it must already be in the kernel namespace.\n",
    "words = torch.LongTensor([list(map(vocab.to_index, tokens))])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22 out of 22 characters were found in pretrained elmo embedding.\n",
      "torch.Size([1, 5, 256])\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import ElmoEmbedding\n",
    "\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# Small English ELMo model, created with requires_grad=False.\n",
    "embed = ElmoEmbedding(vocab, model_dir_or_name='en-small', requires_grad=False)\n",
    "\n",
    "# `vocab` and `words` are reused by the two ELMo cells that follow.\n",
    "words = torch.LongTensor([[vocab.to_index(tok) for tok in tokens]])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22 out of 22 characters were found in pretrained elmo embedding.\n",
      "torch.Size([1, 5, 512])\n"
     ]
    }
   ],
   "source": [
    "# Request ELMo layers 1 and 2; the recorded output width is 512, against 256 without `layers`.\n",
    "embed = ElmoEmbedding(\n",
    "    vocab,\n",
    "    model_dir_or_name='en-small',\n",
    "    requires_grad=False,\n",
    "    layers='1,2',\n",
    ")\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22 out of 22 characters were found in pretrained elmo embedding.\n",
      "torch.Size([1, 5, 256])\n"
     ]
    }
   ],
   "source": [
    "# layers='mix': the outputs of the three layers are added element-wise according to their weights.\n",
    "embed = ElmoEmbedding(\n",
    "    vocab,\n",
    "    model_dir_or_name='en-small',\n",
    "    requires_grad=True,\n",
    "    layers='mix',\n",
    ")\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "torch.Size([1, 5, 768])\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import BertEmbedding\n",
    "\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# Cased English BERT base model with the default layer selection.\n",
    "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')\n",
    "\n",
    "# `vocab` and `words` are reused by the BERT cells that follow.\n",
    "words = torch.LongTensor([list(map(vocab.to_index, tokens))])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "torch.Size([1, 5, 1536])\n"
     ]
    }
   ],
   "source": [
    "# Use the outputs of the last two layers.\n",
    "embed = BertEmbedding(\n",
    "    vocab,\n",
    "    model_dir_or_name='en-base-cased',\n",
    "    layers='10,11',\n",
    ")\n",
    "# The selected layers are concatenated along the last dimension.\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "torch.Size([1, 7, 768])\n"
     ]
    }
   ],
   "source": [
    "# With include_cls_sep=True the sequence dimension of the result grows by 2.\n",
    "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', include_cls_sep=True)\n",
    "\n",
    "# Run the forward pass once and reuse it for both the size check and the [CLS] slice.\n",
    "bert_out = embed(words)\n",
    "print(bert_out.size())\n",
    "\n",
    "# Take the sentence's [CLS] representation: position 0 of every sequence.\n",
    "cls_reps = bert_out[:, 0]  # shape: [batch_size, 768]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "torch.Size([1, 5, 768])\n"
     ]
    }
   ],
   "source": [
    "# pool_method='max' -- presumably how word-piece vectors are pooled into one vector per word;\n",
    "# confirm against the BertEmbedding documentation. The output shape is unchanged.\n",
    "embed = BertEmbedding(\n",
    "    vocab,\n",
    "    model_dir_or_name='en-base-cased',\n",
    "    layers='-1',\n",
    "    pool_method='max',\n",
    ")\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 10 words out of 10.\n",
      "torch.Size([1, 9, 768])\n"
     ]
    }
   ],
   "source": [
    "# Two sentences in one sequence, with an explicit [SEP] token between them.\n",
    "tokens = \"this is a demo . [SEP] another sentence .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', layers='-1', pool_method='max')\n",
    "\n",
    "# Nine tokens in, nine positions out (see the recorded size).\n",
    "token_ids = [vocab.to_index(tok) for tok in tokens]\n",
    "words = torch.LongTensor([token_ids])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start constructing character vocabulary.\n",
      "In total, there are 8 distinct characters.\n",
      "torch.Size([1, 5, 64])\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import CNNCharEmbedding\n",
    "\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# The character embedding has dimension 50; the returned embedding has dimension 64.\n",
    "embed = CNNCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
    "\n",
    "words = torch.LongTensor([list(map(vocab.to_index, tokens))])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start constructing character vocabulary.\n",
      "In total, there are 8 distinct characters.\n",
      "torch.Size([1, 5, 64])\n"
     ]
    }
   ],
   "source": [
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import LSTMCharEmbedding\n",
    "\n",
    "tokens = \"this is a demo .\".split()\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(tokens)\n",
    "\n",
    "# The character embedding has dimension 50; the returned embedding has dimension 64.\n",
    "embed = LSTMCharEmbedding(vocab, embed_size=64, char_emb_size=50)\n",
    "\n",
    "token_ids = [vocab.to_index(tok) for tok in tokens]\n",
    "words = torch.LongTensor([token_ids])\n",
    "print(embed(words).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 5 out of 7 words in the pre-training embedding.\n",
      "50\n",
      "Start constructing character vocabulary.\n",
      "In total, there are 8 distinct characters.\n",
      "30\n",
      "22 out of 22 characters were found in pretrained elmo embedding.\n",
      "256\n",
      "22 out of 22 characters were found in pretrained elmo embedding.\n",
      "512\n",
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "768\n",
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n",
      "1536\n",
      "80\n"
     ]
    }
   ],
   "source": [
    "# Explicit imports instead of `import *`, so every name used in this cell has a visible origin.\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import (\n",
    "    BertEmbedding,\n",
    "    CNNCharEmbedding,\n",
    "    ElmoEmbedding,\n",
    "    StackEmbedding,\n",
    "    StaticEmbedding,\n",
    ")\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"this is a demo .\".split())\n",
    "\n",
    "# Each embedding reports its output width through `embedding_dim`.\n",
    "static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')\n",
    "print(static_embed.embedding_dim)  # 50\n",
    "char_embed = CNNCharEmbedding(vocab, embed_size=30)\n",
    "print(char_embed.embedding_dim)    # 30\n",
    "elmo_embed_1 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='2')\n",
    "print(elmo_embed_1.embedding_dim)  # 256\n",
    "elmo_embed_2 = ElmoEmbedding(vocab, model_dir_or_name='en-small', layers='1,2')\n",
    "print(elmo_embed_2.embedding_dim)  # 512\n",
    "bert_embed_1 = BertEmbedding(vocab, layers='-1', model_dir_or_name='en-base-cased')\n",
    "print(bert_embed_1.embedding_dim)  # 768\n",
    "bert_embed_2 = BertEmbedding(vocab, layers='2,-1', model_dir_or_name='en-base-cased')\n",
    "print(bert_embed_2.embedding_dim)  # 1536\n",
    "# Stacking the 50-d static and 30-d character embeddings reports 80.\n",
    "stack_embed = StackEmbedding([static_embed, char_embed])\n",
    "print(stack_embed.embedding_dim)  # 80"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/vocab.txt\n",
      "Load pre-trained BERT parameters from file /remote-home/ynzheng/.fastNLP/embedding/bert-base-cased/pytorch_model.bin.\n",
      "Start to generate word pieces for word.\n",
      "Found(Or segment into word pieces) 7 words out of 7.\n"
     ]
    }
   ],
   "source": [
    "# Explicit imports instead of `import *`, so every name used in this cell has a visible origin.\n",
    "from fastNLP import Vocabulary\n",
    "from fastNLP.embeddings import BertEmbedding\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"this is a demo .\".split())\n",
    "\n",
    "# Created as trainable (requires_grad=True) ...\n",
    "embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased', requires_grad=True)\n",
    "# ... then switched so that the BertEmbedding weights are no longer updated.\n",
    "embed.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.3633, -0.2091, -0.0353, -0.3771, -0.5193]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.0926, -0.4812, -0.7744,  0.4836, -0.5475]],\n",
      "       grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"The the a A\".split())\n",
    "\n",
    "# A randomly initialised StaticEmbedding is used for the demo; the effect is the same\n",
    "# with pre-trained word vectors.\n",
    "# The keyword is `model_dir_or_name`, as in the earlier StaticEmbedding cells.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)\n",
    "\n",
    "# Without lower=True the recorded vectors for 'The' and 'the' differ.\n",
    "print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('the')])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All word in the vocab have been lowered. There are 6 words, 4 unique lowered words.\n",
      "tensor([[ 0.4530, -0.1558, -0.1941,  0.3203,  0.0355]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.4530, -0.1558, -0.1941,  0.3203,  0.0355]],\n",
      "       grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"The the a A\".split())\n",
    "\n",
    "# A randomly initialised StaticEmbedding is used for the demo; the effect is the same\n",
    "# with pre-trained vectors.\n",
    "# The keyword is `model_dir_or_name`, as in the earlier StaticEmbedding cells.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)\n",
    "\n",
    "# With lower=True the recorded vectors for 'The' and 'the' are identical.\n",
    "print(embed(torch.LongTensor([vocab.to_index('The')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('the')])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 out of 4 words have frequency less than 2.\n",
      "tensor([[ 0.4724, -0.7277, -0.6350, -0.5258, -0.6063]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.7638, -0.0552,  0.1625, -0.2210,  0.4993]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.7638, -0.0552,  0.1625, -0.2210,  0.4993]],\n",
      "       grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"the the the a\".split())\n",
    "\n",
    "# A randomly initialised StaticEmbedding is used for the demo; the effect is the same\n",
    "# with pre-trained vectors.\n",
    "# The keyword is `model_dir_or_name`, as in the earlier StaticEmbedding cells.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, min_freq=2)\n",
    "\n",
    "# 'a' occurs once (below min_freq=2); its recorded vector equals the unknown-token vector.\n",
    "print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
    "print(embed(torch.LongTensor([vocab.unknown_idx])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 out of 5 words have frequency less than 2.\n",
      "All word in the vocab have been lowered. There are 5 words, 4 unique lowered words.\n",
      "tensor([[ 0.1943,  0.3739,  0.2769, -0.4746, -0.3181]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.5892, -0.6916,  0.7319, -0.3803,  0.4979]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[ 0.5892, -0.6916,  0.7319, -0.3803,  0.4979]],\n",
      "       grad_fn=<EmbeddingBackward>)\n",
      "tensor([[-0.1348, -0.2172, -0.0071,  0.5704, -0.2607]],\n",
      "       grad_fn=<EmbeddingBackward>)\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "from fastNLP import Vocabulary\n",
    "\n",
    "vocab = Vocabulary()\n",
    "vocab.add_word_lst(\"the the the a A\".split())\n",
    "\n",
    "# A randomly initialised StaticEmbedding is used for the demo; the effect is the same\n",
    "# with pre-trained vectors.\n",
    "# The keyword is `model_dir_or_name`, as in the earlier StaticEmbedding cells.\n",
    "embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, min_freq=2, lower=True)\n",
    "\n",
    "# With lower=True and min_freq=2, the recorded vectors for 'a' and 'A' are identical and\n",
    "# differ from the unknown-token vector.\n",
    "print(embed(torch.LongTensor([vocab.to_index('the')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('a')])))\n",
    "print(embed(torch.LongTensor([vocab.to_index('A')])))\n",
    "print(embed(torch.LongTensor([vocab.unknown_idx])))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python Now",
   "language": "python",
   "name": "now"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
